#include "database_mysql_connection.h"

#include <mysqld_error.h>
#include <errmsg.h>

#include "log.h"
#include "timer.h"
#include "utils.h"
#include "database_preparedstatement.h"
#include "database_preparedstatement_mysql.h"
#include "database_worker.h"
#include "database_mysql_wrap.h"
#include "database_mysql_head.h"
#include "database_query_result.h"
#include "database_transaction.h"

namespace afcore {
namespace database {

/**
 * Parses a ';'-separated connection string of the form
 * "host;port_or_socket;user;password;database".
 *
 * When the string does not split into exactly five fields, the members are
 * left untouched.
 */
SMysqlConnectionInfo::SMysqlConnectionInfo(const std::string& info) {
  CTokenizer tokens(info, ';');
  if (tokens.size() != 5) {
    return;
  }

  host.assign(tokens[0]);
  port_or_socket.assign(tokens[1]);
  user.assign(tokens[2]);
  password.assign(tokens[3]);
  database.assign(tokens[4]);
}

/// Synchronous connection: stores the connection info and the sync flag.
/// Unlike the asynchronous constructor, no CDatabaseWorker is created here.
[[maybe_unused]] CMysqlConnection::CMysqlConnection(SMysqlConnectionInfo& conn_info)
  : conn_info_(conn_info), conn_flags_(kConnectionFlags_Sync) {}

/// Asynchronous connection: remembers the operation queue and creates a
/// CDatabaseWorker bound to that queue and to this connection
/// (presumably the worker drains the queue — confirm in database_worker).
CMysqlConnection::CMysqlConnection(SMysqlConnectionInfo& conn_info, CPcQueue<CSqlOperation*>* queue)
  : queue_(queue), conn_info_(conn_info), conn_flags_(kConnectionFlags_Async) {
  // Constructed in the body so that queue_ is already initialised when it is passed on.
  worker_ = std::make_unique<CDatabaseWorker>(queue_, this);
}

/// Destroys the worker, the prepared statements and the MySQL handle via Close().
CMysqlConnection::~CMysqlConnection() { this->Close(); }

/**
 * Opens a new MySQL connection described by conn_info_ and stores it in mysql_.
 *
 * A host of "." selects a named pipe on Windows, or the unix socket whose path is
 * held in port_or_socket on other platforms. Any other host connects over TCP,
 * with port_or_socket parsed as the port number.
 *
 * @return 0 on success, CR_UNKNOWN_ERROR if mysql_init() fails, or the client
 *         error code reported for the failed mysql_real_connect() call.
 */
uint32_t CMysqlConnection::Open() {
  MYSQL* mysql_init_ptr = mysql_init(nullptr);
  if (!mysql_init_ptr) {
    LOG_ERROR("Could not initialize Mysql connection to database `{}`", conn_info_.database.c_str());
    return CR_UNKNOWN_ERROR;
  }

  // The effective connection target lives in locals; conn_info_ is deliberately NOT
  // rewritten. If host "." were replaced by "localhost", a later reconnect would skip
  // the socket branch and silently ignore the configured socket path.
  const char* host = conn_info_.host.c_str();
  unsigned int port = 0;
  const char* unix_socket = nullptr;

  mysql_options(mysql_init_ptr, MYSQL_SET_CHARSET_NAME, "utf8");
#if AFCORE_PLATFORM_WIN == AFCORE_PLATFORM
  if ("." == conn_info_.host) {
    // Named-pipe connection.
    unsigned int opt = MYSQL_PROTOCOL_PIPE;
    mysql_options(mysql_init_ptr, MYSQL_OPT_PROTOCOL, reinterpret_cast<const char*>(&opt));
  } else {
    port = static_cast<unsigned int>(strtol(conn_info_.port_or_socket.c_str(), nullptr, 10));
  }
#else
  if ("." == conn_info_.host) {
    // Unix-socket connection; port_or_socket holds the socket path.
    unsigned int opt = MYSQL_PROTOCOL_SOCKET;
    mysql_options(mysql_init_ptr, MYSQL_OPT_PROTOCOL, reinterpret_cast<const char*>(&opt));
    host = "localhost";
    unix_socket = conn_info_.port_or_socket.c_str();
  } else {
    port = static_cast<unsigned int>(strtol(conn_info_.port_or_socket.c_str(), nullptr, 10));
  }
#endif

  mysql_ = reinterpret_cast<SMysqlHandle*>(mysql_real_connect(mysql_init_ptr, host, conn_info_.user.c_str(),
    conn_info_.password.c_str(), conn_info_.database.c_str(), port, unix_socket, 0));

  if (!mysql_) {
    LOG_ERROR("Could not connect to MySQL database at {}: {}", host, mysql_error(mysql_init_ptr));
    // Capture the error code before mysql_close() frees the handle it lives in.
    const uint32_t error_code = mysql_errno(mysql_init_ptr);
    mysql_close(mysql_init_ptr);
    return error_code;
  }

  // Library/server versions are reported on the initial connect only, not on every reconnect.
  if (!reconnecting_) {
    LOG_INFO("MySQL client library: {}", mysql_get_client_info());
    LOG_INFO("MySQL server version: {}", mysql_get_server_info(mysql_));
  }

  LOG_INFO("Connected to MySQL database at {}", host);
  mysql_autocommit(mysql_, 1);
  mysql_set_character_set(mysql_, "utf8");
  return 0;
}

/// Tears the connection down. The worker is destroyed first, then the prepared
/// statements, and finally the MySQL handle is closed if one is open.
void CMysqlConnection::Close() {
  worker_.reset();
  stmts_.clear();

  if (mysql_ == nullptr) {
    return;
  }
  mysql_close(mysql_);
  mysql_ = nullptr;
}

/// Runs DoPreparedStatement() and reports success as "prepare_error_ is not set".
/// Presumably DoPreparedStatement() registers statements via PreparedStatement(),
/// which sets prepare_error_ on failure — confirm in the derived class.
bool CMysqlConnection::PrepareStatements() {
  DoPreparedStatement();
  const bool all_prepared = !prepare_error_;
  return all_prepared;
}

bool CMysqlConnection::Execute(const char *sql) {
  if (!mysql_) {
    return false;
  }

  {
    uint32_t time = GetMsTimeFromAppStart();

    if (mysql_query(mysql_, sql)) {
      uint32_t lerrno = mysql_errno(mysql_);

      LOG_INFO("SQL: {}", sql);
      LOG_ERROR("[{}] {}", lerrno, mysql_error(mysql_));

      if (HandleMysqlErrno(lerrno)) {
        return Execute(sql);
      }

      return false;
    } else {
      LOG_ERROR("[{} ms] SQL: {}", GetMsTimeDiff(time, GetMsTimeFromAppStart()), sql);
    }
  }

  return true;
}

/**
 * Executes a prepared statement whose result set (if any) is not needed.
 *
 * The parameters of `stmt` are bound onto this connection's CPreparedStatementMysql
 * at stmt->index_, and then the statement is executed. If HandleMysqlErrno()
 * reconnects (which re-prepares all statements), the whole call is retried.
 *
 * @return true on success; false if not connected, if no statement is prepared at
 *         that index on this connection, or on an unhandled MySQL error.
 */
bool CMysqlConnection::Execute(CPreparedStatementBase *stmt) {
  if (!mysql_) {
    return false;
  }

  const uint32_t index = stmt->index_;

  CPreparedStatementMysql* stmt_mysql = GetPreparedStatement(index);
  DBG_ASSERT(stmt_mysql);
  if (!stmt_mysql) {
    // Guard for builds where DBG_ASSERT is compiled out; GetPreparedStatement() already logged.
    return false;
  }
  stmt_mysql->stmt_ = stmt;

  stmt->BindParameters(stmt_mysql);

  SMysqlStmt* mysql_STMT = stmt_mysql->GetStmt();
  SMysqlBind* mysql_BIND = stmt_mysql->GetBind();

  const uint32_t time = GetMsTimeFromAppStart();

  if (mysql_stmt_bind_param(mysql_STMT, mysql_BIND)) {
    // Statement-level failures are recorded on the statement handle.
    const uint32_t lerrno = mysql_stmt_errno(mysql_STMT);
    LOG_ERROR("SQL(p): {} [{}] {}", stmt_mysql->GetQueryString().c_str(), lerrno, mysql_stmt_error(mysql_STMT));

    // After a reconnect stmts_ was rebuilt and stmt_mysql is stale: restart from scratch.
    if (HandleMysqlErrno(lerrno)) {
      return Execute(stmt);
    }

    stmt_mysql->ClearParameters();
    return false;
  }

  if (mysql_stmt_execute(mysql_STMT)) {
    const uint32_t lerrno = mysql_stmt_errno(mysql_STMT);
    LOG_ERROR("SQL(p): {} [{}] {}", stmt_mysql->GetQueryString().c_str(), lerrno, mysql_stmt_error(mysql_STMT));

    if (HandleMysqlErrno(lerrno)) {
      return Execute(stmt);
    }

    stmt_mysql->ClearParameters();
    return false;
  }

  LOG_DEBUG("[{} ms] SQL(p): {}", GetMsTimeDiff(time, GetMsTimeFromAppStart()), stmt_mysql->GetQueryString().c_str());

  stmt_mysql->ClearParameters();
  return true;
}

/**
 * Runs a text query and wraps its buffered result in a heap-allocated CResultSet
 * (allocated with `new`; presumably the caller takes ownership).
 *
 * @return nullptr for a null `sql`, or whenever _Query() fails (including a result with zero rows).
 */
CResultSet *CMysqlConnection::Query(const char *sql) {
  if (sql == nullptr) {
    return nullptr;
  }

  SMysqlResult* result = nullptr;
  SMysqlField* fields = nullptr;
  uint64_t row_count = 0;
  uint32_t field_count = 0;

  const bool ok = _Query(sql, &result, &fields, &row_count, &field_count);
  return ok ? new CResultSet(result, fields, row_count, field_count) : nullptr;
}

/**
 * Executes a prepared query and wraps the outcome in a heap-allocated
 * CPreparedResultSet (allocated with `new`; presumably owned by the caller).
 *
 * @return nullptr when _Query() fails.
 */
CPreparedResultSet *CMysqlConnection::Query(CPreparedStatementBase *stmt) {
  SMysqlResult* result = nullptr;
  uint64_t row_count = 0;
  uint32_t field_count = 0;

  const bool queried = _Query(stmt, &result, &row_count, &field_count);
  if (!queried) {
    return nullptr;
  }

  // Step past one additional pending result, if the server reports one.
  if (mysql_more_results(mysql_) != 0) {
    mysql_next_result(mysql_);
  }

  return new CPreparedResultSet(stmt->stmt_->GetStmt(), result, row_count, field_count);
}

/// Issues START TRANSACTION; the success flag from Execute() is intentionally discarded.
void CMysqlConnection::BeginTransaction() {
  static_cast<void>(Execute("START TRANSACTION"));
}

/// Issues ROLLBACK; the success flag from Execute() is intentionally discarded.
void CMysqlConnection::RollbackTransaction() {
  static_cast<void>(Execute("ROLLBACK"));
}

/// Issues COMMIT; the success flag from Execute() is intentionally discarded.
void CMysqlConnection::CommitTransaction() {
  static_cast<void>(Execute("COMMIT"));
}

int CMysqlConnection::ExecuteTransaction(std::shared_ptr<CTransactionBase> transaction) {
  std::vector<SSqlElementData> const& queries = transaction->queries_;
  if (queries.empty()) {
    return -1;
  }

  BeginTransaction();

  for (auto iter = queries.begin(); iter != queries.end(); ++iter) {
    SSqlElementData const& data = *iter;
    switch (iter->type) {
      case kSqlElementDataType_Prepared:
        {
          CPreparedStatementBase* stmt = data.element.stmt;
          DBG_ASSERT(stmt);
          if (!Execute(stmt)) {
            LOG_WARN("Transaction aborted. {} queries not executed.", static_cast<uint32_t>(queries.size()));
            int error_code = GetLastError();
            RollbackTransaction();
            return error_code;
          }
        }
        break;
      case kSqlElementDataType_Raw:
        {
          const char* sql = data.element.query;
          DBG_ASSERT(sql);
          if (!Execute(sql)) {
            LOG_WARN("Transaction aborted. {} queries not executed.", static_cast<uint32_t>(queries.size()));
            int error_code = GetLastError();
            RollbackTransaction();
            return error_code;
          }
        }
        break;
    }
  }

  CommitTransaction();

  return 0;
}

/**
 * Escapes `length` bytes of `from` into `to` using the connection's character set
 * (mysql_real_escape_string()). mysql_ is passed unchecked, so an open connection
 * is required. Per the MySQL C API, `to` must hold at least 2*length+1 bytes.
 *
 * @return the length of the escaped string, excluding the terminating NUL.
 */
size_t CMysqlConnection::EscapeString(char *to, const char *from, size_t length) {
  const size_t escaped_length = mysql_real_escape_string(mysql_, to, from, length);
  return escaped_length;
}

/// Pings the server via mysql_ping(); the result is ignored.
/// Does nothing when no handle is open (mysql_ is null after Close() or a lost connection),
/// matching the `!mysql_` guards in Execute()/_Query() — mysql_ping() must not get nullptr.
void CMysqlConnection::Ping() {
  if (!mysql_) {
    return;
  }
  mysql_ping(mysql_);
}

/// Returns mysql_errno() for the connection handle, i.e. the error code of the
/// most recent API call on it (0 if that call succeeded).
uint32_t CMysqlConnection::GetLastError() {
  const uint32_t last_error = mysql_errno(mysql_);
  return last_error;
}

/// Non-blocking attempt to take mutex_; returns true if the caller now holds it.
bool CMysqlConnection::LockIfReady() {
  const bool acquired = mutex_.try_lock();
  return acquired;
}

void CMysqlConnection::Unlock() {
  mutex_.unlock();
}

/// Returns the server version as reported by mysql_get_server_version(),
/// or 0 when no connection handle is open (that function dereferences its argument).
uint32_t CMysqlConnection::GetServerVersion() const {
  if (!mysql_) {
    return 0;
  }
  return mysql_get_server_version(mysql_);
}

/**
 * Looks up this connection's prepared statement registered at `index`.
 *
 * @return the statement, or nullptr (after logging) when `index` is out of range
 *         or the slot is empty (e.g. the statement is not meant for this connection type).
 */
CPreparedStatementMysql *CMysqlConnection::GetPreparedStatement(uint32_t index) {
  DBG_ASSERT(index < stmts_.size());
  // DBG_ASSERT may be compiled out; never index past the end of stmts_ in release builds.
  if (index >= stmts_.size()) {
    LOG_ERROR("Prepared statement index {} is out of range ({} registered) on database `{}`.",
      index, stmts_.size(), conn_info_.database.c_str());
    return nullptr;
  }

  CPreparedStatementMysql* stmt_mysql = stmts_[index].get();
  if (!stmt_mysql) {
    LOG_ERROR("Could not fetch prepared statement {} on database `{}`, connection type: {}.",
      index, conn_info_.database.c_str(), (conn_flags_ & kConnectionFlags_Async) ? "asynchronous" : "synchronous");
  }
  return stmt_mysql;
}

void CMysqlConnection::PreparedStatement(uint32_t index, const std::string &sql, EConnectionFlags flags) {
  if (!(conn_flags_ & flags)) {
    stmts_[index].reset();
    return;
  }

  MYSQL_STMT* stmt = mysql_stmt_init(mysql_);
  if (!stmt) {
    LOG_ERROR("In mysql_stmt_init() id: {}, sql: \"{}\"", index, sql.c_str());
    LOG_ERROR("{}", mysql_error(mysql_));
    prepare_error_ = true;
  } else {
    if (mysql_stmt_prepare(stmt, sql.c_str(), static_cast<unsigned long>(sql.size()))) {
      LOG_ERROR("In mysql_stmt_prepare() id: {}, sql: \"{}\"", index, sql.c_str());
      LOG_ERROR("{}", mysql_error(mysql_));
      mysql_stmt_close(stmt);
      prepare_error_ = true;
    } else {
      stmts_[index] = std::make_unique<CPreparedStatementMysql>(reinterpret_cast<SMysqlStmt*>(stmt), sql);
    }
  }
}

/**
 * Runs a text query and buffers its result with mysql_store_result().
 *
 * If HandleMysqlErrno() reports an error as handled (it reconnected), the query is retried.
 *
 * @param[out] result       buffered result; set to nullptr if it was freed because it had no rows.
 * @param[out] fields       field metadata of `result` (written only on success).
 * @param[out] row_count    value of mysql_affected_rows() after storing the result.
 * @param[out] field_count  value of mysql_field_count().
 * @return true only if a result with at least one row is available.
 */
bool CMysqlConnection::_Query(const char *sql, SMysqlResult **result, SMysqlField **fields, uint64_t *row_count, uint32_t *field_count) {
  if (!mysql_) {
    return false;
  }

  const uint32_t time = GetMsTimeFromAppStart();

  if (mysql_query(mysql_, sql)) {
    const uint32_t lerrno = mysql_errno(mysql_);
    LOG_INFO("SQL: {}", sql);
    LOG_ERROR("[{}] {}", lerrno, mysql_error(mysql_));

    if (HandleMysqlErrno(lerrno)) {
      return _Query(sql, result, fields, row_count, field_count);
    }
    return false;
  }

  // The logger uses "{}" placeholders; a printf-style "%s" would drop the SQL text.
  LOG_DEBUG("[{} ms] SQL: {}", GetMsTimeDiff(time, GetMsTimeFromAppStart()), sql);

  *result = reinterpret_cast<SMysqlResult*>(mysql_store_result(mysql_));
  *row_count = mysql_affected_rows(mysql_);
  *field_count = mysql_field_count(mysql_);

  if (!*result) {
    return false;
  }

  if (!*row_count) {
    mysql_free_result(*result);
    *result = nullptr;  // never hand a dangling pointer back to the caller
    return false;
  }

  *fields = reinterpret_cast<SMysqlField*>(mysql_fetch_fields(*result));
  return true;
}

/**
 * Binds and executes a prepared query on this connection.
 *
 * If HandleMysqlErrno() reconnects (which re-prepares all statements), the whole
 * call is retried.
 *
 * @param[out] result       result metadata from mysql_stmt_result_metadata().
 * @param[out] row_count    mysql_stmt_num_rows() right after execution.
 * @param[out] field_count  mysql_stmt_field_count().
 * @return false if not connected, if no statement is prepared at stmt->index_,
 *         or on an unhandled MySQL error.
 */
bool CMysqlConnection::_Query(CPreparedStatementBase *stmt, SMysqlResult **result, uint64_t *row_count, uint32_t *field_count) {
  if (!mysql_) {
    return false;
  }

  const uint32_t index = stmt->index_;
  CPreparedStatementMysql* stmt_mysql = GetPreparedStatement(index);
  DBG_ASSERT(stmt_mysql);
  if (!stmt_mysql) {
    // Guard for builds where DBG_ASSERT is compiled out; GetPreparedStatement() already logged.
    return false;
  }
  stmt_mysql->stmt_ = stmt;

  stmt->BindParameters(stmt_mysql);

  MYSQL_STMT* mysql_STMT = stmt_mysql->GetStmt();
  MYSQL_BIND* mysql_BIND = stmt_mysql->GetBind();

  const uint32_t time = GetMsTimeFromAppStart();

  if (mysql_stmt_bind_param(mysql_STMT, mysql_BIND)) {
    // Statement-level failures are recorded on the statement handle.
    const uint32_t lerrno = mysql_stmt_errno(mysql_STMT);
    LOG_ERROR("SQL(p): {} [{}] {}", stmt_mysql->GetQueryString().c_str(), lerrno, mysql_stmt_error(mysql_STMT));

    // After a reconnect stmts_ was rebuilt and stmt_mysql is stale: restart from scratch.
    if (HandleMysqlErrno(lerrno)) {
      return _Query(stmt, result, row_count, field_count);
    }

    stmt_mysql->ClearParameters();
    return false;
  }

  if (mysql_stmt_execute(mysql_STMT)) {
    const uint32_t lerrno = mysql_stmt_errno(mysql_STMT);
    LOG_ERROR("SQL(p): {} [{}] {}", stmt_mysql->GetQueryString().c_str(), lerrno, mysql_stmt_error(mysql_STMT));

    if (HandleMysqlErrno(lerrno)) {
      return _Query(stmt, result, row_count, field_count);
    }

    stmt_mysql->ClearParameters();
    return false;
  }

  LOG_DEBUG("[{} ms] SQL(p): {}", GetMsTimeDiff(time, GetMsTimeFromAppStart()), stmt_mysql->GetQueryString().c_str());

  stmt_mysql->ClearParameters();

  *result = reinterpret_cast<SMysqlResult*>(mysql_stmt_result_metadata(mysql_STMT));
  *row_count = mysql_stmt_num_rows(mysql_STMT);
  *field_count = mysql_stmt_field_count(mysql_STMT);

  return true;
}

/**
 * Reacts to a MySQL error code returned by a failed query or statement.
 *
 * @param error_no  the MySQL/client error code to handle.
 * @param attempts  number of reconnect attempts still allowed (the default value is
 *                  declared in the header — not visible here).
 * @return true only when the connection was re-established and all statements were
 *         re-prepared, meaning the caller should retry its operation; false otherwise.
 *
 * Connection-loss codes close the current handle and fall through to the reconnect path.
 * That path calls Open() and, on success, PrepareStatements(). When preparation fails or
 * the attempts run out, the process is aborted after a 10 second sleep. Schema and
 * parse errors also abort.
 *
 * NOTE(review): a failed reconnect recurses with the error code returned by Open(). If
 * that code is not one of the connection cases below, the recursion hits `default` and
 * returns false without aborting, leaving mysql_ == nullptr and reconnecting_ == true.
 */
bool CMysqlConnection::HandleMysqlErrno(uint32_t error_no, uint8_t attempts) {
  switch (error_no) {
    case CR_SERVER_GONE_ERROR:
    case CR_SERVER_LOST:
    case CR_SERVER_LOST_EXTENDED:
      {
        if (mysql_) {
          LOG_ERROR("Lost the connection to the MySQL server!");

          mysql_close(mysql_);
          mysql_ = nullptr;
        }
      }
      [[fallthrough]];
    case CR_CONN_HOST_ERROR:
      {
        LOG_INFO("Attempting to reconnect to the MySQL server...");

        // Open() consults this flag; it is cleared again only after a successful reconnect.
        reconnecting_ = true;

        uint32_t const lerrno = Open();
        if (!lerrno) {
          // Don't remove 'this' pointer unless you want to skip loading all prepared statements...
          if (!this->PrepareStatements()) {
            LOG_FATAL("Could not re-prepare statements!");

            std::this_thread::sleep_for(RSeconds(10));
            std::abort();
          }

          LOG_INFO("Successfully reconnected to {} @{}:{} ({}).",
            conn_info_.database.c_str(), conn_info_.host.c_str(), conn_info_.port_or_socket.c_str(),
            (conn_flags_ & kConnectionFlags_Async) ? "asynchronous" : "synchronous");

          reconnecting_ = false;
          return true;
        }

        if ((--attempts) == 0) {
          // Shut down the server when the mysql server isn't reachable for some time
          LOG_FATAL("Failed to reconnect to the MySQL server, terminating the server to prevent data corruption!");

          // We could also initiate a shutdown through using std::raise(SIGTERM)
          std::this_thread::sleep_for(RSeconds(10));
          std::abort();
        } else {
          // The reconnect attempt itself may fail with 2006 (CR_SERVER_GONE_ERROR);
          // sleep first so the recursive retry does not spin.
          std::this_thread::sleep_for(RSeconds(3)); // Sleep 3 seconds
          return HandleMysqlErrno(lerrno, attempts); // Call self (recursive)
        }
      }
      // NOTE(review): unreachable — every path in the block above returns or aborts.
      [[fallthrough]];
    case ER_LOCK_DEADLOCK:
      // NOTE(review): the handlers named below come from upstream code; confirm that deadlock retry exists in this codebase.
      return false;    // Implemented in TransactionTask::Execute and DatabaseWorkerPool<T>::DirectCommitTransaction
      // Query-related errors: skip the query.
    case ER_WRONG_VALUE_COUNT:
    case ER_DUP_ENTRY:
      return false;

      // Outdated table or database structure - terminate core
    case ER_BAD_FIELD_ERROR:
    case ER_NO_SUCH_TABLE:
      LOG_ERROR("Your database structure is not up to date. Please make sure you've executed all queries in the sql/updates folders.");
      std::this_thread::sleep_for(RSeconds(10));
      std::abort();
      return false;
    case ER_PARSE_ERROR:
      LOG_ERROR("Error while parsing SQL. Core fix required.");
      std::this_thread::sleep_for(RSeconds(10));
      std::abort();
      return false;
    default:
      LOG_ERROR("Unhandled MySQL errno {}. Unexpected behaviour possible.", error_no);
      return false;
  }
}

} // !namespace database
} // !namespace afcore